{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from math import ceil\n",
    "import torch\n",
    "from torch.utils.data import DataLoader\n",
    "import torch.optim as optim\n",
    "import torch.nn as nn\n",
    "\n",
    "import sys\n",
    "sys.path.append('..')\n",
    "from utils.input_pipeline import get_image_folders\n",
    "from utils.training import train, optimization_step\n",
    "    \n",
    "torch.cuda.is_available()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# let cuDNN benchmark the available convolution algorithms and keep the\n",
    "# fastest one; this helps most when input shapes stay the same between batches\n",
    "torch.backends.cudnn.benchmark = True"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Create data iterators"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# mini-batch size for the training loader (also used to count batches per epoch)\n",
    "batch_size = 32"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "100000"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_folder, val_folder = get_image_folders()\n",
    "\n",
    "# settings that are the same for both loaders\n",
    "loader_kwargs = dict(num_workers=4, pin_memory=True)\n",
    "\n",
    "# training batches are reshuffled, validation batches keep a fixed order\n",
    "train_iterator = DataLoader(\n",
    "    train_folder, batch_size=batch_size, shuffle=True, **loader_kwargs\n",
    ")\n",
    "val_iterator = DataLoader(\n",
    "    val_folder, batch_size=256, shuffle=False, **loader_kwargs\n",
    ")\n",
    "\n",
    "# size of the training set = length of the folder's `imgs` list\n",
    "train_size = len(train_folder.imgs)\n",
    "train_size"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "10000"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# size of the validation set = length of the folder's `imgs` list\n",
    "val_size = len(val_folder.imgs)\n",
    "val_size"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from densenet import DenseNet\n",
    "\n",
    "model = DenseNet()\n",
    "# load the model from step1\n",
    "model.load_state_dict(torch.load('model_step1.pytorch_state'))\n",
    "\n",
    "# Move the network to the GPU right away, *before* any parameter lists or the\n",
    "# optimizer are built from it: the torch.optim docs ask for the model to be on\n",
    "# its final device when the optimizer is constructed. Calling `.cuda()` again\n",
    "# later in the notebook is harmless.\n",
    "model = model.cuda()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# sort the network's parameters, by name, into the groups given to the optimizer\n",
    "weights, bn_weights, bn_biases = [], [], []\n",
    "for param_name, param in model.named_parameters():\n",
    "    # anything under a 'conv' module, plus the classifier's weight\n",
    "    if 'conv' in param_name or 'classifier.weight' in param_name:\n",
    "        weights.append(param)\n",
    "    # 'weight' entries of modules whose name contains 'norm'\n",
    "    if 'norm' in param_name and 'weight' in param_name:\n",
    "        bn_weights.append(param)\n",
    "    # 'bias' entries of modules whose name contains 'norm'\n",
    "    if 'norm' in param_name and 'bias' in param_name:\n",
    "        bn_biases.append(param)\n",
    "\n",
    "# the classifier's bias forms a group of its own\n",
    "biases = [model.classifier.bias]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# one dict per parameter group: only `weights` gets an explicit weight decay,\n",
    "# the other groups fall back to the optimizer-wide settings passed to SGD below\n",
    "params = [dict(params=weights, weight_decay=1e-4)]\n",
    "params += [dict(params=group) for group in (biases, bn_weights, bn_biases)]\n",
    "\n",
    "optimizer = optim.SGD(params, lr=1e-5, momentum=0.9, nesterov=True)\n",
    "\n",
    "# criterion and network both live on the GPU\n",
    "loss = nn.CrossEntropyLoss().cuda()\n",
    "model = model.cuda()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Train"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "3125"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# number of passes over the training set\n",
    "n_epochs = 5\n",
    "\n",
    "# batches per epoch, rounded up so a final partial batch is counted too\n",
    "n_batches = ceil(train_size/batch_size)\n",
    "n_batches"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0  2.450 1.717  0.426 0.581  0.689 0.813  2330.438\n",
      "1  2.141 1.577  0.487 0.607  0.745 0.839  2312.572\n",
      "2  2.021 1.495  0.510 0.626  0.766 0.847  2309.705\n",
      "3  1.939 1.443  0.530 0.638  0.778 0.855  2307.756\n",
      "4  1.881 1.398  0.540 0.650  0.789 0.862  2305.410\n",
      "CPU times: user 2h 56min 54s, sys: 16min 15s, total: 3h 13min 10s\n",
      "Wall time: 3h 12min 45s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "def optimization_step_fn(model, loss, x_batch, y_batch):\n",
    "    \"\"\"Run `optimization_step` on one batch with the notebook-level `optimizer`.\"\"\"\n",
    "    return optimization_step(model, loss, x_batch, y_batch, optimizer)\n",
    "\n",
    "# Columns of each printed row:\n",
    "#   epoch | logloss | accuracy | top5_accuracy | time\n",
    "# logloss, accuracy and top5_accuracy are printed as pairs\n",
    "# (first value: train, second value: val); time is a single value.\n",
    "all_losses = train(\n",
    "    model, loss, optimization_step_fn, train_iterator, val_iterator, n_epochs\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Save"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# bring the network back to the CPU first, then write its state dict to disk\n",
    "torch.save(model.cpu().state_dict(), 'model_step2.pytorch_state')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
